{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LTj2A0RUQFNo"
   },
   "source": [
    "# Convert our BiRefNet weights to onnx format.\n",
    "\n",
    "> This colab file is modified from [Kazuhito00](https://github.com/Kazuhito00)'s nice work.\n",
    "\n",
    "> Repo: https://github.com/Kazuhito00/BiRefNet-ONNX-Sample  \n",
    "> Original Colab: https://colab.research.google.com/github/Kazuhito00/BiRefNet-ONNX-Sample/blob/main/Convert2ONNX.ipynb\n",
    "\n",
    "+ Currently, Colab with 12.7GB RAM / 15GB GPU Mem cannot hold the transformation of BiRefNet in default setting. So, I take BiRefNet with swin_v1_tiny backbone as an example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Online Colab version: https://colab.research.google.com/drive/1z6OruR52LOvDDpnp516F-N4EyPGrp5om"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "781JHjLJmveh"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "weights_file = 'BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth'  # https://github.com/ZhengPeng7/BiRefNet/releases/download/v1/BiRefNet-general-bb_swin_v1_tiny-epoch_232.pth\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('config.py') as fp:\n",
    "    file_lines = fp.read()\n",
    "if 'swin_v1_tiny' in weights_file:\n",
    "    print('Set `swin_v1_tiny` as the backbone.')\n",
    "    file_lines = file_lines.replace(\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][6]\n",
    "        ''',\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][3]\n",
    "        ''',\n",
    "    )\n",
    "    with open('config.py', mode=\"w\") as fp:\n",
    "        fp.write(file_lines)\n",
    "else:\n",
    "    file_lines = file_lines.replace(\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][3]\n",
    "        ''',\n",
    "        '''\n",
    "            'pvt_v2_b2', 'pvt_v2_b5',               # 9-bs10, 10-bs5\n",
    "        ][6]\n",
    "        ''',\n",
    "    )\n",
    "    with open('config.py', mode=\"w\") as fp:\n",
    "        fp.write(file_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7lFgKfPS8Icy"
   },
   "outputs": [],
   "source": [
    "from utils import check_state_dict\n",
    "from models.birefnet import BiRefNet\n",
    "\n",
    "\n",
    "birefnet = BiRefNet(bb_pretrained=False)\n",
    "state_dict = torch.load('./{}'.format(weights_file), map_location=device)\n",
    "state_dict = check_state_dict(state_dict)\n",
    "birefnet.load_state_dict(state_dict)\n",
    "\n",
    "torch.set_float32_matmul_precision(['high', 'highest'][0])\n",
    "\n",
    "birefnet.to(device)\n",
    "_ = birefnet.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JVgJAdgxQVJW"
   },
   "source": [
    "# Process deform_conv2d in the conversion to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vJiZv0L75kTe"
   },
   "outputs": [],
   "source": [
    "from torchvision.ops.deform_conv import DeformConv2d\n",
    "import deform_conv2d_onnx_exporter\n",
    "\n",
    "# register deform_conv2d operator\n",
    "deform_conv2d_onnx_exporter.register_deform_conv2d_onnx_op()\n",
    "\n",
    "def convert_to_onnx(net, file_name='output.onnx', input_shape=(1024, 1024), device=device):\n",
    "    input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)\n",
    "\n",
    "    input_layer_names = ['input_image']\n",
    "    output_layer_names = ['output_image']\n",
    "\n",
    "    torch.onnx.export(\n",
    "        net,\n",
    "        input,\n",
    "        file_name,\n",
    "        verbose=False,\n",
    "        opset_version=17,\n",
    "        input_names=input_layer_names,\n",
    "        output_names=output_layer_names,\n",
    "    )\n",
    "convert_to_onnx(birefnet, weights_file.replace('.pth', '.onnx'), input_shape=(1024, 1024), device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-eU-g40P1zS-"
   },
   "source": [
    "# Load ONNX weights and do the inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LZ4HVqcoDvto"
   },
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "transform_image = transforms.Compose([\n",
    "    transforms.Resize((1024, 1024)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "imagepath = './Helicopter-HR.jpg'\n",
    "image = Image.open(imagepath)\n",
    "input_images = transform_image(image).unsqueeze(0).to(device)\n",
    "input_images_numpy = input_images.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rwzdKX1EfYkd"
   },
   "outputs": [],
   "source": [
    "import onnxruntime\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']\n",
    "onnx_session = onnxruntime.InferenceSession(\n",
    "    weights_file.replace('.pth', '.onnx'),\n",
    "    providers=providers\n",
    ")\n",
    "input_name = onnx_session.get_inputs()[0].name\n",
    "print(onnxruntime.get_device(), onnx_session.get_providers())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DJVtxZUZum4-"
   },
   "outputs": [],
   "source": [
    "from time import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "time_st = time()\n",
    "pred_onnx = torch.tensor(\n",
    "    onnx_session.run(None, {input_name: input_images_numpy if device == 'cpu' else input_images_numpy})[-1]\n",
    ").squeeze(0).sigmoid().cpu()\n",
    "print(time() - time_st)\n",
    "\n",
    "plt.imshow(pred_onnx.squeeze(), cmap='gray'); plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    preds = birefnet(input_images)[-1].sigmoid().cpu()\n",
    "plt.imshow(preds.squeeze(), cmap='gray'); plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "diff = abs(preds - pred_onnx)\n",
    "print('sum(diff):', diff.sum())\n",
    "plt.imshow((diff).squeeze(), cmap='gray'); plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qzYHflt92Bjd"
   },
   "source": [
    "# Efficiency Comparison between .pth and .onnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "A5IYfT-uzphA",
    "outputId": "2999e345-950e-41b3-ddd3-9f58a71a3f21"
   },
   "outputs": [],
   "source": [
    "%%timeit\n",
    "with torch.no_grad():\n",
    "    preds = birefnet(input_images)[-1].sigmoid().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G0Ul4rfNg1za"
   },
   "outputs": [],
   "source": [
    "%%timeit\n",
    "pred_onnx = torch.tensor(\n",
    "    onnx_session.run(None, {input_name: input_images_numpy})[-1]\n",
    ").squeeze(0).sigmoid().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
